{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h1 align=\"center\">Advanced Registration</h1>\n",
    "\n",
    "\n",
    "**Summary:**\n",
    "1. SimpleITK provides two flavors of non-rigid registration:\n",
    "   * Free Form Deformation (BSpline based) and Demons, both using the ITKv4 registration framework.\n",
    "   * A set of Demons filters that are independent of the registration framework (`DemonsRegistrationFilter`, `DiffeomorphicDemonsRegistrationFilter`, `FastSymmetricForcesDemonsRegistrationFilter`, `SymmetricForcesDemonsRegistrationFilter`).\n",
    "2. Registration evaluation:\n",
    "   * Registration accuracy: the quantity of interest is the Target Registration Error (TRE).\n",
    "   * TRE is spatially variant.\n",
    "   * Surrogate metrics for evaluating registration accuracy such as segmentation overlaps are relevant, but are potentially deficient.\n",
    "   * Registration time.\n",
    "   * Acceptable values for TRE and runtime are context dependent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import SimpleITK as sitk\n",
    "import registration_gui as rgui\n",
    "import utilities \n",
    "\n",
    "from downloaddata import fetch_data as fdata\n",
    "\n",
    "from ipywidgets import interact, fixed\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data and Registration Task\n",
    "\n",
    "In this notebook we will use the Point-validated Pixel-based Breathing Thorax Model (POPI). This is a 4D (3D+time) thoracic-abdominal CT (10 CTs representing the respiratory cycle) with masks segmenting each of the CTs to air/body/lung, and a set of corresponding landmarks localized in each of the CT volumes.\n",
    "\n",
    "The registration problem we deal with is non-rigid alignment of the lungs throughout the respiratory cycle. This information is relevant for radiation therapy planning and execution.\n",
    "\n",
    "\n",
    "The POPI model is provided by the Léon Bérard Cancer Center & CREATIS Laboratory, Lyon, France. The relevant publication is:\n",
    "\n",
    "J. Vandemeulebroucke, D. Sarrut, P. Clarysse, \"The POPI-model, a point-validated pixel-based breathing thorax model\",\n",
    "Proc. XVth International Conference on the Use of Computers in Radiation Therapy (ICCR), Toronto, Canada, 2007.\n",
    "\n",
    "Additional 4D CT data sets with reference points are available from the CREATIS Laboratory <a href=\"http://www.creatis.insa-lyon.fr/rio/popi-model?action=show&redirect=popi\">here</a>. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "images = []\n",
    "masks = []\n",
    "points = []\n",
    "image_indexes = [0,7]\n",
    "for i in image_indexes:\n",
    "    image_file_name = 'POPI/meta/{0}0-P.mhd'.format(i)\n",
    "    mask_file_name = 'POPI/masks/{0}0-air-body-lungs.mhd'.format(i)\n",
    "    points_file_name = 'POPI/landmarks/{0}0-Landmarks.pts'.format(i)\n",
    "    images.append(sitk.ReadImage(fdata(image_file_name), sitk.sitkFloat32)) \n",
    "    masks.append(sitk.ReadImage(fdata(mask_file_name)))\n",
    "    points.append(utilities.read_POPI_points(fdata(points_file_name)))\n",
    "        \n",
    "interact(rgui.display_coronal_with_overlay, temporal_slice=(0,len(images)-1), \n",
    "         coronal_slice = (0, images[0].GetSize()[1]-1), \n",
    "         images = fixed(images), masks = fixed(masks), \n",
    "         label=fixed(utilities.popi_lung_label), window_min = fixed(-1024), window_max=fixed(976));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Free Form Deformation\n",
    "\n",
    "Define a BSplineTransform using a sparse set of grid points overlaid onto the fixed image's domain to deform it.\n",
    "\n",
    "For the current registration task we are fortunate in that we have a unique setting. The images are of the same patient during respiration so we can initialize the registration using the identity transform. Additionally, we have masks demarcating the lungs.\n",
    "\n",
    "We use the registration framework taking advantage of its ability to use masks that limit the similarity metric estimation to points lying inside our region of interest, the lungs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fixed_index = 0\n",
    "moving_index = 1\n",
    "\n",
    "fixed_image = images[fixed_index]\n",
    "fixed_image_mask = masks[fixed_index] == utilities.popi_lung_label\n",
    "fixed_points = points[fixed_index]\n",
    "\n",
    "moving_image = images[moving_index]\n",
    "moving_image_mask = masks[moving_index] == utilities.popi_lung_label\n",
    "moving_points = points[moving_index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a simple callback which allows us to monitor registration progress.\n",
    "def iteration_callback(filter):\n",
    "    print('\\r{0:.2f}'.format(filter.GetMetricValue()), end='')\n",
    "\n",
    "registration_method = sitk.ImageRegistrationMethod()\n",
    "    \n",
    "# Determine the number of BSpline control points using the physical \n",
    "# spacing we want for the finest resolution control grid. \n",
    "grid_physical_spacing = [50.0, 50.0, 50.0] # A control point every 50mm\n",
    "image_physical_size = [size*spacing for size,spacing in zip(fixed_image.GetSize(), fixed_image.GetSpacing())]\n",
    "mesh_size = [int(image_size/grid_spacing + 0.5) \\\n",
    "             for image_size,grid_spacing in zip(image_physical_size,grid_physical_spacing)]\n",
    "# The starting mesh size will be 1/4 of the original, it will be refined by \n",
    "# the multi-resolution framework.\n",
    "mesh_size = [int(sz/4 + 0.5) for sz in mesh_size]\n",
    "\n",
    "initial_transform = sitk.BSplineTransformInitializer(image1 = fixed_image, \n",
    "                                                     transformDomainMeshSize = mesh_size, order=3)    \n",
    "# Instead of the standard SetInitialTransform we use the BSpline specific method which also\n",
    "# accepts the scaleFactors parameter to refine the BSpline mesh. In this case we start with \n",
    "# the given mesh_size at the highest pyramid level then we double it in the next lower level and\n",
    "# in the full resolution image we use a mesh that is four times the original size.\n",
    "registration_method.SetInitialTransformAsBSpline(initial_transform,\n",
    "                                                 inPlace=False,\n",
    "                                                 scaleFactors=[1,2,4])\n",
    "\n",
    "registration_method.SetMetricAsMeanSquares()\n",
    "registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)\n",
    "registration_method.SetMetricSamplingPercentage(0.01)\n",
    "registration_method.SetMetricFixedMask(fixed_image_mask)\n",
    "    \n",
    "registration_method.SetShrinkFactorsPerLevel(shrinkFactors = [4,2,1])\n",
    "registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2,1,0])\n",
    "registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()\n",
    "\n",
    "registration_method.SetInterpolator(sitk.sitkLinear)\n",
    "registration_method.SetOptimizerAsLBFGS2(solutionAccuracy=1e-2, numberOfIterations=100, deltaConvergenceTolerance=0.01)\n",
    "\n",
    "registration_method.AddCommand(sitk.sitkIterationEvent, lambda: iteration_callback(registration_method))\n",
    "\n",
    "final_transformation = registration_method.Execute(fixed_image, moving_image)\n",
    "print('\\nOptimizer\\'s stopping condition, {0}'.format(registration_method.GetOptimizerStopConditionDescription()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Qualitative Evaluation via Segmentation Transfer\n",
    "\n",
    "Transfer the segmentation from the moving image to the fixed image before and after registration and visually evaluate overlap."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformed_segmentation = sitk.Resample(moving_image_mask,\n",
    "                                         fixed_image,\n",
    "                                         final_transformation, \n",
    "                                         sitk.sitkNearestNeighbor,\n",
    "                                         0.0, \n",
    "                                         moving_image_mask.GetPixelID())\n",
    "\n",
    "interact(rgui.display_coronal_with_overlay, temporal_slice=(0,1), \n",
    "         coronal_slice = (0, fixed_image.GetSize()[1]-1), \n",
    "         images = fixed([fixed_image,fixed_image]), masks = fixed([moving_image_mask, transformed_segmentation]), \n",
    "         label=fixed(1), window_min = fixed(-1024), window_max=fixed(976));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quantitative Evaluation\n",
    "\n",
    "The most appropriate evaluation is based on analysis of the Target Registration Error (TRE), which is defined as follows:\n",
    "\n",
    "Given the transformation $T_f^m$ and corresponding points in the two coordinate systems, ${}^fp$ and ${}^mp$, which were not used in the registration process, TRE is defined as $\\|T_f^m({}^fp) - {}^mp\\|$.\n",
    "\n",
    "We start by looking at some descriptive statistics of TRE:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "initial_TRE = utilities.target_registration_errors(sitk.Transform(), fixed_points, moving_points)\n",
    "final_TRE = utilities.target_registration_errors(final_transformation, fixed_points, moving_points)\n",
    "\n",
    "print('Initial alignment errors in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(np.mean(initial_TRE), \n",
    "                                                                                               np.std(initial_TRE), \n",
    "                                                                                               np.max(initial_TRE)))\n",
    "print('Final alignment errors in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(np.mean(final_TRE), \n",
    "                                                                                               np.std(final_TRE), \n",
    "                                                                                               np.max(final_TRE)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The above descriptive statistics do not convey the whole picture; we should also look at the TRE distributions before and after registration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(initial_TRE, bins=20, alpha=0.5, label='before registration', color='blue')\n",
    "plt.hist(final_TRE, bins=20, alpha=0.5, label='after registration', color='green')\n",
    "plt.legend()\n",
    "plt.title('TRE histogram');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we should also take into account the fact that TRE is spatially variant (think of a rotation: the error grows with distance from the center of rotation). We should therefore also explore the distribution of errors as a function of the point location."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "initial_errors = utilities.target_registration_errors(sitk.Transform(), fixed_points, moving_points, display_errors = True)\n",
    "utilities.target_registration_errors(final_transformation, fixed_points, moving_points, \n",
    "                                     min_err=min(initial_errors), max_err=max(initial_errors), display_errors = True);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deciding whether a registration algorithm is appropriate for a specific problem is context dependent and is defined by the clinical/research needs both in terms of accuracy and computational complexity."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Demons Based Registration\n",
    "\n",
    "SimpleITK includes a number of filters from the Demons registration family (originally introduced by J. P. Thirion):\n",
    "1. DemonsRegistrationFilter.\n",
    "2. DiffeomorphicDemonsRegistrationFilter.\n",
    "3. FastSymmetricForcesDemonsRegistrationFilter.\n",
    "4. SymmetricForcesDemonsRegistrationFilter.\n",
    "\n",
    "These are appropriate for mono-modal registration. As these filters are independent of the ImageRegistrationMethod, we do not have access to its multiscale framework. Luckily, it is easy to implement our own multiscale framework in SimpleITK, which is what we do in the next cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def smooth_and_resample(image, shrink_factor, smoothing_sigma):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        image: The image we want to resample.\n",
    "        shrink_factor: A number greater than one, such that the new image's size is original_size/shrink_factor.\n",
    "        smoothing_sigma: Sigma for Gaussian smoothing, this is in physical (image spacing) units, not pixels.\n",
    "    Returns:\n",
    "        Image which is a result of smoothing the input and then resampling it using the given sigma and shrink factor.\n",
    "    \"\"\"\n",
    "    smoothed_image = sitk.SmoothingRecursiveGaussian(image, smoothing_sigma)\n",
    "    \n",
    "    original_spacing = image.GetSpacing()\n",
    "    original_size = image.GetSize()\n",
    "    new_size = [int(sz/float(shrink_factor) + 0.5) for sz in original_size]\n",
    "    new_spacing = [((original_sz-1)*original_spc)/(new_sz-1) \n",
    "                   for original_sz, original_spc, new_sz in zip(original_size, original_spacing, new_size)]\n",
    "    return sitk.Resample(smoothed_image, new_size, sitk.Transform(), \n",
    "                         sitk.sitkLinear, image.GetOrigin(),\n",
    "                         new_spacing, image.GetDirection(), 0.0, \n",
    "                         image.GetPixelID())\n",
    "\n",
    "\n",
    "    \n",
    "def multiscale_demons(registration_algorithm,\n",
    "                      fixed_image, moving_image, initial_transform = None, \n",
    "                      shrink_factors=None, smoothing_sigmas=None):\n",
    "    \"\"\"\n",
    "    Run the given registration algorithm in a multiscale fashion. The original scale should not be given as input as the\n",
    "    original images are implicitly incorporated as the base of the pyramid.\n",
    "    Args:\n",
    "        registration_algorithm: Any registration algorithm that has an Execute(fixed_image, moving_image, displacement_field_image)\n",
    "                                method.\n",
    "        fixed_image: Resulting transformation maps points from this image's spatial domain to the moving image spatial domain.\n",
    "        moving_image: Resulting transformation maps points from the fixed_image's spatial domain to this image's spatial domain.\n",
    "        initial_transform: Any SimpleITK transform, used to initialize the displacement field.\n",
    "        shrink_factors: Shrink factors relative to the original image's size.\n",
    "        smoothing_sigmas: Amount of smoothing which is done prior to resampling the image using the given shrink factor. These\n",
    "                          are in physical (image spacing) units.\n",
    "    Returns: \n",
    "        SimpleITK.DisplacementFieldTransform\n",
    "    \"\"\"\n",
    "    # Create image pyramid.\n",
    "    fixed_images = [fixed_image]\n",
    "    moving_images = [moving_image]\n",
    "    if shrink_factors:\n",
    "        for shrink_factor, smoothing_sigma in reversed(list(zip(shrink_factors, smoothing_sigmas))):\n",
    "            fixed_images.append(smooth_and_resample(fixed_images[0], shrink_factor, smoothing_sigma))\n",
    "            moving_images.append(smooth_and_resample(moving_images[0], shrink_factor, smoothing_sigma))\n",
    "    \n",
    "    # Create initial displacement field at lowest resolution. \n",
    "    # Currently, the pixel type is required to be sitkVectorFloat64 because of a constraint imposed by the Demons filters.\n",
    "    if initial_transform:\n",
    "        initial_displacement_field = sitk.TransformToDisplacementField(initial_transform, \n",
    "                                                                       sitk.sitkVectorFloat64,\n",
    "                                                                       fixed_images[-1].GetSize(),\n",
    "                                                                       fixed_images[-1].GetOrigin(),\n",
    "                                                                       fixed_images[-1].GetSpacing(),\n",
    "                                                                       fixed_images[-1].GetDirection())\n",
    "    else:\n",
    "        initial_displacement_field = sitk.Image(fixed_images[-1].GetWidth(), \n",
    "                                                fixed_images[-1].GetHeight(),\n",
    "                                                fixed_images[-1].GetDepth(),\n",
    "                                                sitk.sitkVectorFloat64)\n",
    "        initial_displacement_field.CopyInformation(fixed_images[-1])\n",
    " \n",
    "    # Run the registration.\n",
    "    initial_displacement_field = registration_algorithm.Execute(fixed_images[-1], \n",
    "                                                                moving_images[-1], \n",
    "                                                                initial_displacement_field)\n",
    "    # Start at the top of the pyramid and work our way down.\n",
    "    for f_image, m_image in reversed(list(zip(fixed_images[0:-1], moving_images[0:-1]))):\n",
    "        initial_displacement_field = sitk.Resample(initial_displacement_field, f_image)\n",
    "        initial_displacement_field = registration_algorithm.Execute(f_image, m_image, initial_displacement_field)\n",
    "    return sitk.DisplacementFieldTransform(initial_displacement_field)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we will use our newly minted multiscale framework to perform registration with the Demons filters. Some things you can easily try out by editing the code below:\n",
    "1. Is there really a need for multiscale? Call the `multiscale_demons` function without the `shrink_factors` and `smoothing_sigmas` parameters and compare the results.\n",
    "2. Which Demons filter should you use? Configure the other filters and see whether our selection is the best choice (accuracy/time)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a simple callback which allows us to monitor registration progress.\n",
    "def iteration_callback(filter):\n",
    "    print('\\r{0}: {1:.2f}'.format(filter.GetElapsedIterations(), filter.GetMetric()), end='')\n",
    "    \n",
    "# Select a Demons filter and configure it.\n",
    "demons_filter =  sitk.FastSymmetricForcesDemonsRegistrationFilter()\n",
    "demons_filter.SetNumberOfIterations(20)\n",
    "# Regularization (update field - viscous, total field - elastic).\n",
    "demons_filter.SetSmoothDisplacementField(True)\n",
    "demons_filter.SetStandardDeviations(2.0)\n",
    "\n",
    "# Add our simple callback to the registration filter.\n",
    "demons_filter.AddCommand(sitk.sitkIterationEvent, lambda: iteration_callback(demons_filter))\n",
    "\n",
    "# Run the registration.\n",
    "tx = multiscale_demons(registration_algorithm=demons_filter, \n",
    "                       fixed_image = fixed_image, \n",
    "                       moving_image = moving_image,\n",
    "                       shrink_factors = [4,2],\n",
    "                       smoothing_sigmas = [8,4])\n",
    "\n",
    "# look at the final TREs.\n",
    "final_TRE = utilities.target_registration_errors(tx, fixed_points, moving_points, display_errors = True)\n",
    "\n",
    "print('Final alignment errors in millimeters, mean(std): {:.2f}({:.2f}), max: {:.2f}'.format(np.mean(final_TRE), \n",
    "                                                                                               np.std(final_TRE), \n",
    "                                                                                               np.max(final_TRE)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quantitative Evaluation II (Segmentation)\n",
    "\n",
    "While the use of corresponding points to evaluate registration is the desired approach, it is often not applicable. In many cases there are only a few distinct points which can be localized in the two images, possibly too few to serve as a metric for evaluating the registration result across the whole region of interest.\n",
    "\n",
    "An alternative approach is to use segmentation. In this approach, we independently segment the structures of interest in the two images. After registration, we transfer the segmentation from one image to the other and compare the original and registration-induced segmentations.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transfer the segmentation via the estimated transformation. \n",
    "# Nearest Neighbor interpolation so we don't introduce new labels.\n",
    "transformed_labels = sitk.Resample(masks[moving_index],\n",
    "                                   fixed_image,\n",
    "                                   tx, \n",
    "                                   sitk.sitkNearestNeighbor,\n",
    "                                   0.0, \n",
    "                                   masks[moving_index].GetPixelID())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have now replaced the task of evaluating registration with that of evaluating segmentation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Often referred to as ground truth, but we prefer reference as the truth is never known.\n",
    "reference_segmentation = fixed_image_mask\n",
    "# Segmentations before and after registration\n",
    "segmentations = [moving_image_mask, transformed_labels == utilities.popi_lung_label]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from enum import Enum\n",
    "\n",
    "# Use enumerations to represent the various evaluation measures\n",
    "class OverlapMeasures(Enum):\n",
    "    jaccard, dice, volume_similarity, false_negative, false_positive = range(5)\n",
    "\n",
    "class SurfaceDistanceMeasures(Enum):\n",
    "    hausdorff_distance, mean_surface_distance, median_surface_distance, std_surface_distance, max_surface_distance = range(5)\n",
    "    \n",
    "# Empty numpy arrays to hold the results \n",
    "overlap_results = np.zeros((len(segmentations),len(OverlapMeasures.__members__.items())))  \n",
    "surface_distance_results = np.zeros((len(segmentations),len(SurfaceDistanceMeasures.__members__.items())))  \n",
    "\n",
    "# Compute the evaluation criteria\n",
    "\n",
    "# Note that for the overlap measures filter, because we are dealing with a single label we \n",
    "# use the combined, all labels, evaluation measures without passing a specific label to the methods.\n",
    "overlap_measures_filter = sitk.LabelOverlapMeasuresImageFilter()\n",
    "\n",
    "hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()\n",
    "\n",
    "# Use the absolute values of the distance map to compute the surface distances (distance map sign, outside or inside \n",
    "# relationship, is irrelevant)\n",
    "label = 1\n",
    "reference_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(reference_segmentation, squaredDistance=False))\n",
    "reference_surface = sitk.LabelContour(reference_segmentation)\n",
    "\n",
    "statistics_image_filter = sitk.StatisticsImageFilter()\n",
    "# Get the number of pixels in the reference surface by counting all pixels that are 1.\n",
    "statistics_image_filter.Execute(reference_surface)\n",
    "num_reference_surface_pixels = int(statistics_image_filter.GetSum()) \n",
    "\n",
    "for i, seg in enumerate(segmentations):\n",
    "    # Overlap measures\n",
    "    overlap_measures_filter.Execute(reference_segmentation, seg)\n",
    "    overlap_results[i,OverlapMeasures.jaccard.value] = overlap_measures_filter.GetJaccardCoefficient()\n",
    "    overlap_results[i,OverlapMeasures.dice.value] = overlap_measures_filter.GetDiceCoefficient()\n",
    "    overlap_results[i,OverlapMeasures.volume_similarity.value] = overlap_measures_filter.GetVolumeSimilarity()\n",
    "    overlap_results[i,OverlapMeasures.false_negative.value] = overlap_measures_filter.GetFalseNegativeError()\n",
    "    overlap_results[i,OverlapMeasures.false_positive.value] = overlap_measures_filter.GetFalsePositiveError()\n",
    "    # Hausdorff distance\n",
    "    hausdorff_distance_filter.Execute(reference_segmentation, seg)\n",
    "    surface_distance_results[i,SurfaceDistanceMeasures.hausdorff_distance.value] = hausdorff_distance_filter.GetHausdorffDistance()\n",
    "    # Symmetric surface distance measures\n",
    "    segmented_distance_map = sitk.Abs(sitk.SignedMaurerDistanceMap(seg, squaredDistance=False))\n",
    "    segmented_surface = sitk.LabelContour(seg)\n",
    "        \n",
    "    # Multiply the binary surface segmentations with the distance maps. The resulting distance\n",
    "    # maps contain non-zero values only on the surface (they can also contain zero on the surface)\n",
    "    seg2ref_distance_map = reference_distance_map*sitk.Cast(segmented_surface, sitk.sitkFloat32)\n",
    "    ref2seg_distance_map = segmented_distance_map*sitk.Cast(reference_surface, sitk.sitkFloat32)\n",
    "        \n",
    "    # Get the number of pixels in the segmented surface by counting all pixels that are 1.\n",
    "    statistics_image_filter.Execute(segmented_surface)\n",
    "    num_segmented_surface_pixels = int(statistics_image_filter.GetSum())\n",
    "    \n",
    "    # Get all non-zero distances and then add zero distances if required.\n",
    "    seg2ref_distance_map_arr = sitk.GetArrayViewFromImage(seg2ref_distance_map)\n",
    "    seg2ref_distances = list(seg2ref_distance_map_arr[seg2ref_distance_map_arr!=0]) \n",
    "    seg2ref_distances = seg2ref_distances + \\\n",
    "                        list(np.zeros(num_segmented_surface_pixels - len(seg2ref_distances)))\n",
    "    ref2seg_distance_map_arr = sitk.GetArrayViewFromImage(ref2seg_distance_map)\n",
    "    ref2seg_distances = list(ref2seg_distance_map_arr[ref2seg_distance_map_arr!=0]) \n",
    "    ref2seg_distances = ref2seg_distances + \\\n",
    "                        list(np.zeros(num_reference_surface_pixels - len(ref2seg_distances)))\n",
    "        \n",
    "    all_surface_distances = seg2ref_distances + ref2seg_distances\n",
    "    \n",
    "    surface_distance_results[i,SurfaceDistanceMeasures.mean_surface_distance.value] = np.mean(all_surface_distances)\n",
    "    surface_distance_results[i,SurfaceDistanceMeasures.median_surface_distance.value] = np.median(all_surface_distances)\n",
    "    surface_distance_results[i,SurfaceDistanceMeasures.std_surface_distance.value] = np.std(all_surface_distances)\n",
    "    surface_distance_results[i,SurfaceDistanceMeasures.max_surface_distance.value] = np.max(all_surface_distances)\n",
    "\n",
    "import pandas as pd\n",
    "from IPython.display import display, HTML \n",
    "\n",
    "# Graft our results matrix into pandas data frames \n",
    "overlap_results_df = pd.DataFrame(data=overlap_results, index=[\"before registration\", \"after registration\"], \n",
    "                                  columns=[name for name, _ in OverlapMeasures.__members__.items()]) \n",
    "surface_distance_results_df = pd.DataFrame(data=surface_distance_results, index=[\"before registration\", \"after registration\"], \n",
    "                                  columns=[name for name, _ in SurfaceDistanceMeasures.__members__.items()]) \n",
    "\n",
    "# Display the data as HTML tables and graphs\n",
    "display(HTML(overlap_results_df.to_html(float_format=lambda x: '%.3f' % x)))\n",
    "display(HTML(surface_distance_results_df.to_html(float_format=lambda x: '%.3f' % x)))\n",
    "overlap_results_df.plot(kind='bar', rot=1).legend(bbox_to_anchor=(1.6,0.9))\n",
    "surface_distance_results_df.plot(kind='bar', rot=1).legend(bbox_to_anchor=(1.6,0.9));   "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "<a href=\"07_registration_application.ipynb\"><h2 align=right>Next &raquo;</h2></a>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
